{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Experimental notebook that slightly modifies the basic PyTorch API usage to prefetch from the data loaders and to calculate the loss on each GPU (instead of gathering the outputs on cuda:0). However, this gave no real improvement in speed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch for data-parallel training: when True, LR and BATCHSIZE are multiplied by\n",
    "# GPU_COUNT and get_symbol() wraps the model in DataParallelNoGather.\n",
    "MULTI_GPU = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import time\n",
    "import multiprocessing\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torchvision.models as models\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torch.nn.init as init\n",
    "import torchvision.transforms as transforms\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from sklearn.metrics.ranking import roc_auc_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from PIL import Image\n",
    "from common.utils import download_data_chextxray, get_imgloc_labels, get_train_valid_test_split\n",
    "from common.utils import compute_roc_auc, get_cuda_version, get_cudnn_version, get_gpu_name\n",
    "from common.utils import yield_mb\n",
    "from common.params_dense import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OS:  linux\n",
      "Python:  3.5.4 |Anaconda custom (64-bit)| (default, Nov 20 2017, 18:44:38) \n",
      "[GCC 7.2.0]\n",
      "PyTorch:  0.4.0\n",
      "Numpy:  1.14.1\n",
      "GPU:  ['Tesla V100-PCIE-16GB', 'Tesla V100-PCIE-16GB', 'Tesla V100-PCIE-16GB', 'Tesla V100-PCIE-16GB']\n",
      "CUDA Version 9.0.176\n",
      "CuDNN Version  7.0.5\n"
     ]
    }
   ],
   "source": [
    "# Log the software / hardware environment this run was produced on\n",
    "print(\"OS: \", sys.platform)\n",
    "print(\"Python: \", sys.version)\n",
    "print(\"PyTorch: \", torch.__version__)\n",
    "print(\"Numpy: \", np.__version__)\n",
    "print(\"GPU: \", get_gpu_name())\n",
    "print(get_cuda_version())\n",
    "print(\"CuDNN Version \", get_cudnn_version())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPUs:  24\n",
      "GPUs:  4\n"
     ]
    }
   ],
   "source": [
    "# Hardware counts; GPU_COUNT is the number of names returned by get_gpu_name()\n",
    "CPU_COUNT = multiprocessing.cpu_count()\n",
    "GPU_COUNT = len(get_gpu_name())\n",
    "print(\"CPUs: \", CPU_COUNT)\n",
    "print(\"GPUs: \", GPU_COUNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "chestxray/images chestxray/Data_Entry_2017.csv\n"
     ]
    }
   ],
   "source": [
    "# Model-params: per-channel mean / std handed to transforms.Normalize further down\n",
    "IMAGENET_RGB_MEAN_TORCH = [0.485, 0.456, 0.406]\n",
    "IMAGENET_RGB_SD_TORCH = [0.229, 0.224, 0.225]\n",
    "# Paths: the image folder and the label CSV both live under CSV_DEST\n",
    "CSV_DEST = \"chestxray\"\n",
    "IMAGE_FOLDER = os.path.join(CSV_DEST, \"images\")\n",
    "LABEL_FILE = os.path.join(CSV_DEST, \"Data_Entry_2017.csv\")\n",
    "print(IMAGE_FOLDER, LABEL_FILE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manually scale to multi-gpu\n",
    "assert torch.cuda.is_available()\n",
    "_DEVICE = torch.device(\"cuda:0\")\n",
    "# enables cudnn's auto-tuner\n",
    "torch.backends.cudnn.benchmark=True\n",
    "# Recompute LR / BATCHSIZE from the untouched base values in common.params_dense instead of\n",
    "# multiplying them in place, so re-running this cell cannot compound the scaling.\n",
    "import common.params_dense as _base_params\n",
    "_GPU_SCALE = GPU_COUNT if MULTI_GPU else 1\n",
    "LR = _base_params.LR * _GPU_SCALE\n",
    "BATCHSIZE = _base_params.BATCHSIZE * _GPU_SCALE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Please make sure to download\n",
      "https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-linux#download-and-install-azcopy\n",
      "Data already exists\n",
      "CPU times: user 580 ms, sys: 240 ms, total: 819 ms\n",
      "Wall time: 819 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Download data (the helper printed \"Data already exists\" on this run, so nothing was fetched)\n",
    "# Recorded wall time for the download: 17min 58s\n",
    "print(\"Please make sure to download\")\n",
    "print(\"https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-linux#download-and-install-azcopy\")\n",
    "download_data_chextxray(CSV_DEST)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Data Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise by imagenet mean/sd (the same transform is reused by the train, valid and test pipelines)\n",
    "normalize = transforms.Normalize(IMAGENET_RGB_MEAN_TORCH,\n",
    "                                 IMAGENET_RGB_SD_TORCH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class XrayData(Dataset):\n",
    "    \"\"\"Dataset yielding (image, label) pairs for a set of patient ids.\n",
    "\n",
    "    Image paths and labels come from get_imgloc_labels(img_dir, lbl_file, patient_ids).\n",
    "    `transform`, when given, is applied to every image before it is returned.\n",
    "    \"\"\"\n",
    "    def __init__(self, img_dir, lbl_file, patient_ids, transform=None):\n",
    "        self.img_locs, self.labels = get_imgloc_labels(img_dir, lbl_file, patient_ids)\n",
    "        self.transform = transform\n",
    "        print(\"Loaded {} labels and {} images\".format(len(self.labels), len(self.img_locs)))\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return (image, torch.FloatTensor(label)) for sample `idx`.\"\"\"\n",
    "        # Image.open is lazy and keeps the file handle open. convert() forces the pixel\n",
    "        # data to be read and guarantees 3 channels (greyscale / RGBA files included),\n",
    "        # and the with-block closes the handle straight afterwards.\n",
    "        with Image.open(self.img_locs[idx]) as im_raw:\n",
    "            im_rgb = im_raw.convert(\"RGB\")\n",
    "        if self.transform is not None:\n",
    "            im_rgb = self.transform(im_rgb)\n",
    "        return im_rgb, torch.FloatTensor(self.labels[idx])\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.img_locs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def no_augmentation_dataset(img_dir, lbl_file, patient_ids, normalize):\n",
    "    \"\"\"Build an XrayData set with a fixed pipeline: resize, to-tensor, then `normalize`.\"\"\"\n",
    "    eval_pipeline = transforms.Compose([\n",
    "        transforms.Resize(WIDTH),\n",
    "        transforms.ToTensor(),\n",
    "        normalize,\n",
    "    ])\n",
    "    return XrayData(img_dir, lbl_file, patient_ids, transform=eval_pipeline)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train:21563 valid:3080 test:6162\n"
     ]
    }
   ],
   "source": [
    "# Split into train / valid / test sets (the helper prints the three sizes);\n",
    "# the results are passed as `patient_ids` to the datasets below\n",
    "train_set, valid_set, test_set = get_train_valid_test_split(TOT_PATIENT_NUMBER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 87306 labels and 87306 images\n"
     ]
    }
   ],
   "source": [
    "# Dataset for training\n",
    "train_dataset = XrayData(img_dir=IMAGE_FOLDER,\n",
    "                         lbl_file=LABEL_FILE,\n",
    "                         patient_ids=train_set,\n",
    "                         transform=transforms.Compose([\n",
    "                             transforms.RandomResizedCrop(size=WIDTH),\n",
    "                             transforms.RandomHorizontalFlip(),\n",
    "                             transforms.ToTensor(),  # need to convert image to tensor!\n",
    "                             normalize]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 7616 labels and 7616 images\n",
      "Loaded 17198 labels and 17198 images\n"
     ]
    }
   ],
   "source": [
    "# Validation / test datasets use the augmentation-free pipeline\n",
    "valid_dataset = no_augmentation_dataset(IMAGE_FOLDER, LABEL_FILE, valid_set, normalize)\n",
    "test_dataset = no_augmentation_dataset(IMAGE_FOLDER, LABEL_FILE, test_set, normalize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DataParallelNoGather(torch.nn.DataParallel):\n",
    "    def gather(self, outputs, output_device):\n",
    "        return outputs  # no concat to output-device here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DataParallelCriterion(torch.nn.DataParallel):\n",
    "    def forward(self, inputs, *targets, **kwargs):\n",
    "            if not self.device_ids:\n",
    "                return self.module(inputs, *targets, **kwargs)\n",
    "            # Since output not gathered, scatter targets for multi-gpu loss cal\n",
    "            targets, kwargs = self.scatter(targets, kwargs, self.device_ids)\n",
    "            if len(self.device_ids) == 1:\n",
    "                return self.module(inputs, *targets[0], **kwargs[0])\n",
    "            # Return a list of losses on each gpu\n",
    "            return [self.module(inputs[i], *targets[i], **kwargs[i]) for i in range(len(inputs))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_symbol(out_features=CLASSES, multi_gpu=MULTI_GPU):\n",
    "    \"\"\"Return a pretrained DenseNet-121 with a Linear + Sigmoid head, moved to _DEVICE.\n",
    "\n",
    "    With `multi_gpu` set, the network is wrapped in DataParallelNoGather first.\n",
    "    \"\"\"\n",
    "    net = models.densenet.densenet121(pretrained=True)\n",
    "    # Replace classifier (FC-1000) with (FC-14)\n",
    "    in_features = net.classifier.in_features\n",
    "    head = nn.Sequential(nn.Linear(in_features, out_features), nn.Sigmoid())\n",
    "    net.classifier = head\n",
    "    if multi_gpu:\n",
    "        net = DataParallelNoGather(net)\n",
    "    # CUDA\n",
    "    net.to(_DEVICE)\n",
    "    return net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_symbol(sym, lr=LR):\n",
    "    \"\"\"Create the (optimizer, criterion, scheduler) triple for the model `sym`.\"\"\"\n",
    "    # BCE Loss since classes not mutually exclusive + Sigmoid FC-layer\n",
    "    loss_fn = nn.BCELoss()\n",
    "    adam = optim.Adam(sym.parameters(), lr=lr, betas=(0.9, 0.999))\n",
    "    plateau = ReduceLROnPlateau(adam, factor=0.1, patience=5, mode='min')\n",
    "    return adam, loss_fn, plateau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(model, dataloader, optimizer, criterion):\n",
    "    \"\"\"Run one optimisation pass over `dataloader` and print the mean training loss.\n",
    "\n",
    "    Args:\n",
    "        model: network called as model(data); its result is handed to `criterion`.\n",
    "        dataloader: iterable of (data, target) batches.\n",
    "        optimizer: optimiser that is zeroed and stepped once per batch.\n",
    "        criterion: callable returning an iterable of losses (one per device).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `dataloader` yields no batches.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    print(\"Training epoch\")\n",
    "    \n",
    "    # Accumulate loss on gpu to avoid cpu-gpu comms\n",
    "    # Multi-gpu tensors to avoid gpu to cuda:0 comms\n",
    "    loss_val = [torch.zeros(1).cuda(dev) for dev in range(GPU_COUNT)]\n",
    "    num_batches = 0\n",
    "    \n",
    "    for data, target in dataloader:\n",
    "        # Get samples (both async)\n",
    "        data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)\n",
    "        # Forwards (modified to return ungathered prediction), so output is a list\n",
    "        outputs = model(data)\n",
    "        # Losses (list) for each gpu\n",
    "        all_losses = criterion(outputs, target)\n",
    "        # Back-prop\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        # Log the loss. detach() keeps the running totals out of the autograd graph;\n",
    "        # adding the attached loss would retain every batch's history for the whole epoch.\n",
    "        for j, l in enumerate(all_losses):\n",
    "            # Not calling .item() which is blocking\n",
    "            loss_val[j] += l.detach()\n",
    "        \n",
    "        # NOTE: this is much more efficient than calling .backward() on losses in a for loop\n",
    "        torch.autograd.backward(all_losses)\n",
    "        optimizer.step()\n",
    "        num_batches += 1\n",
    "        \n",
    "    if num_batches == 0:\n",
    "        raise ValueError(\"dataloader yielded no batches\")\n",
    "    # Single device-to-host transfer per GPU at the end of the epoch\n",
    "    total_loss = sum(acc.cpu() for acc in loss_val)\n",
    "    avg_loss = (total_loss.numpy()/GPU_COUNT/num_batches)[0]\n",
    "    print(\"Training loss: {0:.4f}\".format(avg_loss))\n",
    "    print(\"~~~~~~~\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def valid_epoch(model, dataloader, criterion, phase='valid', cl=CLASSES):\n",
    "    \"\"\"Evaluate `model` over `dataloader` without gradients and return the mean loss.\n",
    "\n",
    "    Args:\n",
    "        model: network called as model(data); returns a sequence of per-device outputs.\n",
    "        dataloader: iterable of (data, target) batches. In the 'testing' phase it must\n",
    "            also provide len(), `batch_size` and `dataset.labels` (i.e. be a DataLoader).\n",
    "        criterion: callable returning an iterable of losses (one per device).\n",
    "        phase: 'testing' additionally collects the predictions and prints the AUC from\n",
    "            compute_roc_auc; any other value is reported as validation.\n",
    "        cl: forwarded to compute_roc_auc as its third argument.\n",
    "\n",
    "    Returns:\n",
    "        The loss averaged over batches and GPU_COUNT.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `dataloader` yields no batches.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    testing = phase == 'testing'\n",
    "    if testing:\n",
    "        print(\"Testing epoch\")\n",
    "    else:\n",
    "        print(\"Validating epoch\")\n",
    "    # Don't save gradients\n",
    "    with torch.no_grad():\n",
    "        if testing:\n",
    "            # pre-allocate predictions (on gpu), padded up to a whole number of batches\n",
    "            len_pred = len(dataloader)*(dataloader.batch_size)\n",
    "            num_lab = dataloader.dataset.labels.shape[-1]\n",
    "            out_pred = torch.cuda.FloatTensor(len_pred, num_lab).fill_(0)\n",
    "            # Next free row of out_pred. A running offset is required because the last\n",
    "            # batch can be shorter, so batch_index * current_batch_len is the wrong position.\n",
    "            next_row = 0\n",
    "        # Accumulate loss on gpu to avoid cpu-gpu comms\n",
    "        loss_val = [torch.zeros(1).cuda(dev) for dev in range(GPU_COUNT)]\n",
    "        num_batches = 0\n",
    "        for data, target in dataloader:\n",
    "            # Get samples\n",
    "            data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)\n",
    "            # Forwards\n",
    "            outputs = model(data)\n",
    "            # Loss\n",
    "            all_losses = criterion(outputs, target)\n",
    "            # Log the loss\n",
    "            for j, l in enumerate(all_losses):\n",
    "                # Not calling .item() which is blocking\n",
    "                loss_val[j] += l\n",
    "            # Log for AUC\n",
    "            if testing:\n",
    "                # The per-device chunks live on different GPUs; bring them onto the\n",
    "                # prediction buffer's device before concatenating.\n",
    "                output = torch.cat([chunk.to(out_pred.device) for chunk in outputs])\n",
    "                out_pred[next_row:next_row + output.size(0)] = output\n",
    "                next_row += output.size(0)\n",
    "            num_batches += 1\n",
    "        if num_batches == 0:\n",
    "            raise ValueError(\"dataloader yielded no batches\")\n",
    "        # Final loss\n",
    "        total_loss = sum(acc.cpu() for acc in loss_val)\n",
    "        avg_loss = (total_loss.numpy()/GPU_COUNT/num_batches)[0]\n",
    "    if testing:\n",
    "        out_gt = dataloader.dataset.labels\n",
    "        out_pred = out_pred.cpu().numpy()[:len(out_gt)]  # Trim padding\n",
    "        print(\"Test-Dataset loss: {0:.4f}\".format(avg_loss))\n",
    "        print(\"Test-Dataset AUC: {0:.4f}\".format(compute_roc_auc(out_gt, out_pred, cl)))\n",
    "    else:\n",
    "        print(\"Validation loss: {0:.4f}\".format(avg_loss))\n",
    "    return avg_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimal to use fewer workers than CPU_COUNT\n",
    "# DataLoaders (pin_memory=True goes together with the non_blocking=True .cuda() copies\n",
    "# in train_epoch / valid_epoch); only the training loader shuffles\n",
    "train_loader = DataLoader(dataset=train_dataset, batch_size=BATCHSIZE,\n",
    "                          shuffle=True, num_workers=6, pin_memory=True)\n",
    "# Using a bigger batch-size (than BATCHSIZE) for below worsens performance\n",
    "valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCHSIZE,\n",
    "                          shuffle=False, num_workers=6, pin_memory=True)\n",
    "test_loader = DataLoader(dataset=test_dataset, batch_size=BATCHSIZE,\n",
    "                         shuffle=False, num_workers=6, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Train CheXNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda/envs/py35/lib/python3.5/site-packages/torchvision-0.2.1-py3.5.egg/torchvision/models/densenet.py:212: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4.23 s, sys: 1.47 s, total: 5.7 s\n",
      "Wall time: 6.09 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Load symbol (pretrained DenseNet-121 with the replaced classifier head, see get_symbol)\n",
    "chexnet_sym = get_symbol()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.81 ms, sys: 0 ns, total: 1.81 ms\n",
      "Wall time: 1.81 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Load optimiser, loss\n",
    "# Scheduler for LRPlateau is not used\n",
    "optimizer, criterion, scheduler = init_symbol(chexnet_sym)\n",
    "# Calculate loss on all GPUs: the wrapper is given the model's own device_ids, which it\n",
    "# uses to scatter the targets (the model outputs are left ungathered)\n",
    "criterion = DataParallelCriterion(criterion, chexnet_sym.device_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training epoch\n",
      "Training loss: 0.1741\n",
      "~~~~~~~\n",
      "Validating epoch\n",
      "Validation loss: 0.1472\n",
      "Epoch time: 135 seconds\n",
      "Training epoch\n",
      "Training loss: 0.1590\n",
      "~~~~~~~\n",
      "Validating epoch\n",
      "Validation loss: 0.1444\n",
      "Epoch time: 108 seconds\n",
      "Training epoch\n",
      "Training loss: 0.1558\n",
      "~~~~~~~\n",
      "Validating epoch\n",
      "Validation loss: 0.1442\n",
      "Epoch time: 111 seconds\n",
      "Training epoch\n",
      "Training loss: 0.1541\n",
      "~~~~~~~\n",
      "Validating epoch\n",
      "Validation loss: 0.1429\n",
      "Epoch time: 107 seconds\n",
      "Training epoch\n",
      "Training loss: 0.1527\n",
      "~~~~~~~\n",
      "Validating epoch\n",
      "Validation loss: 0.1410\n",
      "Epoch time: 106 seconds\n",
      "CPU times: user 13min 37s, sys: 2min 51s, total: 16min 29s\n",
      "Wall time: 9min 28s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# 4 GPU - Main training loop: 9min 28s\n",
    "# Main train/val loop\n",
    "train_iter = iter(train_loader) # Prefetch some training data in the background\n",
    "\n",
    "# NOTE(review): each iterator is created one phase before it is consumed, presumably so the\n",
    "# loader workers fill their queues while the GPUs are busy - keep this statement order.\n",
    "for j in range(EPOCHS):\n",
    "    stime = time.time()\n",
    "    valid_iter = iter(valid_loader) # Will start fetching validation data \n",
    "    train_epoch(chexnet_sym, train_iter, optimizer, criterion)\n",
    "    # NOTE(review): on the last epoch this iterator is created but never consumed\n",
    "    train_iter = iter(train_loader) # Will prefetch some training data\n",
    "    # loss_val is not read again in this cell (the LR scheduler is not stepped)\n",
    "    loss_val = valid_epoch(chexnet_sym, valid_iter, criterion)   \n",
    "    print(\"Epoch time: {0:.0f} seconds\".format(time.time()-stime))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing epoch\n",
      "Test-Dataset loss: 0.1538\n",
      "Full AUC [0.8180020240235686, 0.8601479720224297, 0.7969336459283853, 0.89202362445056, 0.8854655649319612, 0.9015711126353438, 0.721529667743366, 0.8872964038058377, 0.624164759241935, 0.8489753774960405, 0.7411880319380078, 0.7823526644862778, 0.742070371621199, 0.8745168006955891]\n",
      "Test-Dataset AUC: 0.8126\n",
      "CPU times: user 14.2 s, sys: 5.61 s, total: 19.8 s\n",
      "Wall time: 1min 21s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# 4 GPU AUC: 0.8126\n",
    "# 'testing' phase: valid_epoch also collects the predictions and prints the test-set AUC\n",
    "test_loss = valid_epoch(chexnet_sym, test_loader, criterion, 'testing')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Synthetic Data (Pure Training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "87296\n"
     ]
    }
   ],
   "source": [
    "# Test on fake-data -> no IO lag\n",
    "# Number of whole batches in the training set; tot_num rounds the sample count down to that\n",
    "batch_in_epoch = len(train_dataset.labels)//BATCHSIZE\n",
    "tot_num = batch_in_epoch * BATCHSIZE\n",
    "print(tot_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw float32 uniforms directly with torch.rand: same U[0, 1) distribution as np.random.rand,\n",
    "# without the transient float64 array and the extra copies made by astype() and torch.tensor()\n",
    "fake_X = torch.rand(tot_num, 3, 224, 224)\n",
    "fake_y = torch.rand(tot_num, CLASSES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training epoch\n",
      "Training loss: 0.6935\n",
      "~~~~~~~\n",
      "Training epoch\n",
      "Training loss: 0.6935\n",
      "~~~~~~~\n",
      "Training epoch\n",
      "Training loss: 0.6934\n",
      "~~~~~~~\n",
      "Training epoch\n",
      "Training loss: 0.6934\n",
      "~~~~~~~\n",
      "Training epoch\n",
      "Training loss: 0.6935\n",
      "~~~~~~~\n",
      "CPU times: user 12min 27s, sys: 1min 32s, total: 13min 59s\n",
      "Wall time: 8min 45s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# 4 GPU - Synthetic data: 8min 45s\n",
    "# Same train_epoch as in the main loop, fed by yield_mb over the in-memory fake tensors\n",
    "# instead of a DataLoader; it keeps updating the already-trained chexnet_sym\n",
    "for j in range(EPOCHS):\n",
    "    train_epoch(chexnet_sym, \n",
    "                yield_mb(fake_X, fake_y, BATCHSIZE, shuffle=False),\n",
    "                optimizer, \n",
    "                criterion)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py35]",
   "language": "python",
   "name": "conda-env-py35-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
